package tree.simple;

import tree.TreeNode;

/**
 * Solution for "Binary Tree Tilt" (problem 563, per the class name): computes the sum of
 * the tilts of every node in a binary tree. A node's tilt is the absolute difference between
 * the sum of the values in its left subtree and the sum of the values in its right subtree.
 */
public class BinaryTreeTilt_563 {
    /** Running total of node tilts, accumulated by {@link #dfs(TreeNode)}. */
    int ans = 0;

    /**
     * Returns the sum of the tilts of all nodes in the tree rooted at {@code root}.
     *
     * <p>The accumulator is reset on every call. This lets one instance be reused for several
     * trees without a previous result leaking into the next one.
     *
     * @param root root of the tree; {@code null} (an empty tree) yields 0
     * @return total tilt of the tree
     */
    public int findTilt(TreeNode root) {
        // Reset first: otherwise a second call on the same instance would add its result on
        // top of the previous call's total.
        ans = 0;
        dfs(root);
        return ans;
    }

    /**
     * Recursive post-order traversal that adds each visited node's tilt to {@link #ans} and
     * returns the sum of all node values in the subtree rooted at {@code node}.
     *
     * <p>Recursion depth equals the tree height, so very deep (degenerate) trees may exhaust
     * the call stack.
     *
     * @param node subtree root; may be {@code null}
     * @return sum of the node values in the subtree, or 0 for {@code null}
     */
    public int dfs(TreeNode node) {
        if (node == null) {
            return 0;
        }
        int leftSum = dfs(node.left);
        int rightSum = dfs(node.right);
        // Tilt of this node: |sum(left subtree) - sum(right subtree)|.
        ans += Math.abs(leftSum - rightSum);
        return leftSum + rightSum + node.val;
    }
}
